/**
 * Copyright 2024 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import { z, type Genkit } from 'genkit';
import {
  GenerationCommonConfigSchema,
  getBasicUsageStats,
  modelRef,
  type CandidateData,
  type GenerateRequest,
  type ModelAction,
  type ModelInfo,
  type ModelReference,
} from 'genkit/model';
import type { GoogleAuth } from 'google-auth-library';
import type { PluginOptions } from './common/types.js';
import { predictModel, type PredictClient } from './predict.js';

/**
 * See https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/imagen-api.
 */
export const ImagenConfigSchema = GenerationCommonConfigSchema.extend({
  // TODO: Remove common config schema extension since Imagen models don't support
  // most of the common config parameters. Also, add more parameters like sampleCount
  // from the above reference.
  language: z
    .enum(['auto', 'en', 'es', 'hi', 'ja', 'ko', 'pt', 'zh-TW', 'zh', 'zh-CN'])
    .describe('Language of the prompt text.')
    .optional(),
  aspectRatio: z
    .enum(['1:1', '9:16', '16:9', '3:4', '4:3'])
    .describe('Desired aspect ratio of the output image.')
    .optional(),
  negativePrompt: z
    .string()
    .describe(
      'A description of what to discourage in the generated images. ' +
        'For example: "animals" (removes animals), "blurry" ' +
        '(makes the image clearer), "text" (removes text), or ' +
        '"cropped" (removes cropped images).'
    )
    .optional(),
  seed: z
    .number()
    .int()
    .min(1)
    .max(2147483647)
    .describe(
      'Controls the randomization of the image generation process. Use the ' +
        'same seed across requests to provide consistency, or change it to ' +
        'introduce variety in the response.'
    )
    .optional(),
  location: z
    .string()
    .describe('Google Cloud region, e.g. us-central1.')
    .optional(),
  personGeneration: z
    .enum(['dont_allow', 'allow_adult', 'allow_all'])
    .describe('Controls whether and how images of people are generated.')
    .optional(),
  safetySetting: z
    .enum(['block_most', 'block_some', 'block_few', 'block_fewest'])
    .describe('Sets the level of safety filtering.')
    .optional(),
  addWatermark: z
    .boolean()
    .describe('Whether to add an invisible watermark to the generated images.')
    .optional(),
  storageUri: z
    .string()
    .describe('Cloud Storage URI to store the generated images.')
    .optional(),
  mode: z
    .enum(['upscale'])
    .describe("Must be set to 'upscale' for upscaling requests.")
    .optional(),
  /**
   * Describes the editing intention for the request.
   *
   * See https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/imagen-api#edit_images_2 for details.
   */
  editConfig: z
    .object({
      editMode: z
        .enum([
          'inpainting-insert',
          'inpainting-remove',
          'outpainting',
          'product-image',
        ])
        .describe('Editing intention for the request.')
        .optional(),
      maskMode: z
        .object({
          maskType: z
            .enum(['background', 'foreground', 'semantic'])
            .describe(
              '"background" automatically generates a mask for all ' +
                'regions except the primary subject(s) of the image. ' +
                '"foreground" automatically generates a mask for the ' +
                'primary subject(s) of the image. "semantic" segments ' +
                'one or more segmentation classes using class IDs.'
            ),
          classes: z
            .array(z.number())
            .max(5)
            .describe('List of up to 5 class IDs for segmentation.')
            .optional(),
        })
        .describe(
          'Prompts the model to generate a mask instead of requiring you ' +
            'to provide one. When you provide this parameter, you can ' +
            'omit the mask object.'
        )
        .optional(),
      maskDilation: z
        .number()
        .describe('How much to dilate the provided mask, from 0.0 to 1.0.')
        .min(0.0)
        .max(1.0)
        .optional(),
      guidanceScale: z
        .number()
        .describe(
          'Controls how closely the model adheres to the text prompt. ' +
            'Larger values increase alignment between the output and the ' +
            'prompt, but may compromise image quality. Suggested values ' +
            'are 0-9 (low strength), 10-20 (medium strength), and 21+ ' +
            '(high strength).'
        )
        .optional(),
      productPosition: z
        .enum(['reposition', 'fixed'])
        .describe(
          'Defines whether the product should stay fixed or be ' +
            'repositioned.'
        )
        .optional(),
    })
    .passthrough()
    .optional(),
  upscaleConfig: z
    .object({
      upscaleFactor: z
        .enum(['x2', 'x4'])
        .describe('The factor to upscale the image.'),
    })
    .describe('Configuration for upscaling.')
    .optional(),
}).passthrough();
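
// Example (illustrative only): a config object accepted by
// ImagenConfigSchema. The values below are assumptions for demonstration,
// not defaults.
//
//   const config: z.infer<typeof ImagenConfigSchema> = {
//     aspectRatio: '16:9',
//     negativePrompt: 'blurry, text',
//     seed: 42,
//   };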

export const imagen2 = modelRef({
  name: 'vertexai/imagen2',
  info: {
    label: 'Vertex AI - Imagen2',
    versions: ['imagegeneration@006', 'imagegeneration@005'],
    supports: {
      media: false,
      multiturn: false,
      tools: false,
      systemRole: false,
      output: ['media'],
    },
  },
  version: 'imagegeneration@006',
  configSchema: ImagenConfigSchema,
});

export const imagen3 = modelRef({
  name: 'vertexai/imagen3',
  info: {
    label: 'Vertex AI - Imagen3',
    versions: ['imagen-3.0-generate-001'],
    supports: {
      media: true,
      multiturn: false,
      tools: false,
      systemRole: false,
      output: ['media'],
    },
  },
  version: 'imagen-3.0-generate-001',
  configSchema: ImagenConfigSchema,
});

export const imagen3Fast = modelRef({
  name: 'vertexai/imagen3-fast',
  info: {
    label: 'Vertex AI - Imagen3 Fast',
    versions: ['imagen-3.0-fast-generate-001'],
    supports: {
      media: false,
      multiturn: false,
      tools: false,
      systemRole: false,
      output: ['media'],
    },
  },
  version: 'imagen-3.0-fast-generate-001',
  configSchema: ImagenConfigSchema,
});

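/**
 * Imagen models keyed by their actual Vertex AI model version IDs.
 */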
export const ACTUAL_IMAGEN_MODELS = {
  'imagen-3.0-generate-001': modelRef({
    name: 'vertexai/imagen-3.0-generate-001',
    info: {
      label: 'Vertex AI - imagen-3.0-generate-001',
      supports: {
        media: true,
        multiturn: false,
        tools: false,
        systemRole: false,
        output: ['media'],
      },
    },
    configSchema: ImagenConfigSchema,
  }),
  'imagen-3.0-fast-generate-001': modelRef({
    name: 'vertexai/imagen-3.0-fast-generate-001',
    info: {
      label: 'Vertex AI - imagen-3.0-fast-generate-001',
      supports: {
        media: true,
        multiturn: false,
        tools: false,
        systemRole: false,
        output: ['media'],
      },
    },
    configSchema: ImagenConfigSchema,
  }),
} as const;

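/**
 * All Imagen model names this plugin accepts, including legacy aliases.
 */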
export const SUPPORTED_IMAGEN_MODELS = {
  ...ACTUAL_IMAGEN_MODELS,
  // These refs use older, inconsistent model naming and are kept only for
  // backwards compatibility.
  imagen2: imagen2,
  imagen3: imagen3,
  'imagen3-fast': imagen3Fast,
};

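/**
 * Concatenates the text parts of the last message into a single prompt
 * string; non-text parts contribute nothing.
 */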
function extractText(request: GenerateRequest) {
  return request.messages
    .at(-1)!
    .content.map((c) => c.text || '')
    .join('');
}

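/**
 * The `parameters` payload of an Imagen predict request.
 */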
interface ImagenParameters {
  sampleCount?: number;
  aspectRatio?: string;
  negativePrompt?: string;
  seed?: number;
  language?: string;
  personGeneration?: string;
  safetySetting?: string;
  addWatermark?: boolean;
  storageUri?: string;
}

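/**
 * Builds the Imagen `parameters` payload from a GenerateRequest: the
 * requested candidate count becomes `sampleCount`, the config is spread in,
 * and unset values are dropped.
 */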
function toParameters(
  request: GenerateRequest<typeof ImagenConfigSchema>
): ImagenParameters {
  const out = {
    sampleCount: request.candidates ?? 1,
    ...request?.config,
  };

  // Drop unset parameters. Checking for undefined/null (rather than
  // truthiness) preserves explicit `false` values such as addWatermark.
  for (const k in out) {
    if (out[k] === undefined || out[k] === null) delete out[k];
  }

  return out;
}

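/**
 * Returns the base64 payload (the part after the comma in a data URL) of the
 * media part tagged with `metadata.type === 'mask'` in the last message.
 */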
function extractMaskImage(request: GenerateRequest): string | undefined {
  return request.messages
    .at(-1)
    ?.content.find((p) => !!p.media && p.metadata?.type === 'mask')
    ?.media?.url.split(',')[1];
}

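/**
 * Returns the base64 payload of the base image in the last message: a media
 * part with no `metadata.type`, or with `metadata.type === 'base'`.
 */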
function extractBaseImage(request: GenerateRequest): string | undefined {
  return request.messages
    .at(-1)
    ?.content.find(
      (p) => !!p.media && (!p.metadata?.type || p.metadata?.type === 'base')
    )
    ?.media?.url.split(',')[1];
}

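/** A single generated image returned by the Imagen API. */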
interface ImagenPrediction {
  bytesBase64Encoded: string;
  mimeType: string;
}

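/** One entry in the `instances` array of an Imagen predict request. */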
interface ImagenInstance {
  prompt: string;
  image?: { bytesBase64Encoded: string };
  mask?: { image?: { bytesBase64Encoded: string } };
}

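/**
 * Fallback model info for Imagen model names that are not listed in
 * SUPPORTED_IMAGEN_MODELS; defineImagenModel overrides the label per model.
 */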
export const GENERIC_IMAGEN_INFO = {
  label: 'Vertex AI - Generic',
  supports: {
    media: true,
    multiturn: true,
    tools: true,
    systemRole: true,
    output: ['media'],
  },
} as ModelInfo;

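/**
 * Defines a Genkit model action that calls the Imagen predict API, routing
 * each request to a per-location predict client.
 *
 * Minimal usage sketch (illustrative; assumes a Genkit instance configured
 * with this plugin, and that `response.media` is available in your Genkit
 * version):
 *
 *   const response = await ai.generate({
 *     model: 'vertexai/imagen-3.0-generate-001',
 *     prompt: 'a watercolor painting of a lighthouse at dusk',
 *     config: { aspectRatio: '16:9' },
 *   });
 *   const image = response.media; // { url: 'data:image/png;base64,...' }
 */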
export function defineImagenModel(
  ai: Genkit,
  name: string,
  client: GoogleAuth,
  options: PluginOptions
): ModelAction {
  const modelName = `vertexai/${name}`;
  const model: ModelReference<z.ZodTypeAny> =
    SUPPORTED_IMAGEN_MODELS[name] ||
    modelRef({
      name: modelName,
      info: {
        ...GENERIC_IMAGEN_INFO,
        label: `Vertex AI - ${name}`,
      },
      configSchema: ImagenConfigSchema,
    });

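  // Predict clients are created lazily and cached per location so that
  // repeated requests targeting the same region reuse a single client.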
  const predictClients: Record<
    string,
    PredictClient<ImagenInstance, ImagenPrediction, ImagenParameters>
  > = {};
  const predictClientFactory = (
    request: GenerateRequest<typeof ImagenConfigSchema>
  ): PredictClient<ImagenInstance, ImagenPrediction, ImagenParameters> => {
    const requestLocation = request.config?.location || options.location;
    if (!predictClients[requestLocation]) {
      predictClients[requestLocation] = predictModel<
        ImagenInstance,
        ImagenPrediction,
        ImagenParameters
      >(
        client,
        {
          ...options,
          location: requestLocation,
        },
        request.config?.version || model.version || name
      );
    }
    return predictClients[requestLocation];
  };

  return ai.defineModel(
    {
      name: modelName,
      ...model.info,
      configSchema: ImagenConfigSchema,
    },
    async (request) => {
      const instance: ImagenInstance = {
        prompt: extractText(request),
      };
      const baseImage = extractBaseImage(request);
      if (baseImage) {
        instance.image = { bytesBase64Encoded: baseImage };
      }
      const maskImage = extractMaskImage(request);
      if (maskImage) {
        instance.mask = {
          image: { bytesBase64Encoded: maskImage },
        };
      }

      const predictClient = predictClientFactory(request);
      const response = await predictClient([instance], toParameters(request));

      if (!response.predictions || response.predictions.length === 0) {
        throw new Error(
          'Model returned no predictions, possibly due to content filtering.'
        );
      }

      const candidates: CandidateData[] = response.predictions.map((p, i) => {
        const b64data = p.bytesBase64Encoded;
        const mimeType = p.mimeType;
        return {
          index: i,
          finishReason: 'stop',
          message: {
            role: 'model',
            content: [
              {
                media: {
                  url: `data:${mimeType};base64,${b64data}`,
                  contentType: mimeType,
                },
              },
            ],
          },
        };
      });
      return {
        candidates,
        usage: {
          ...getBasicUsageStats(request.messages, candidates),
          custom: { generations: candidates.length },
        },
        custom: response,
      };
    }
  );
}
